import copy

import networkx as nx
import matplotlib.pyplot as plt

from ModularUtils.ControllerConstants import generate_permutations
from ModularUtils.FunctionsConstant import getdoKey


def set_trainGraph(noise_states, latent_state, obs_state, Data_intervs, full_training=True):
    """Build the graph/experiment configuration for the "trainGraph" setup.

    The observed graph has seven variables (X0, X1, X2, W0, W1, Y0, Y1).
    Three latent confounders (U0, U1, U2) each point into the observed
    variables listed in ``confTochild``.

    Args:
        noise_states: Stored as ``label_dim["n<label>"]`` for every observed
            label, and as the second element of each ``noise_params["n<label>"]``.
        latent_state: Stored as ``label_dim[conf]`` and as the second element
            of ``noise_params[conf]`` for each confounder.
        obs_state: Stored as ``label_dim[label]`` for each observed label. It
            is the per-variable value handed to ``generate_permutations`` when
            enumerating intervention assignments.
        Data_intervs: Not used by this configuration. It is kept only so
            existing callers do not break.
        full_training: When True (default, the original behaviour), every
            observed variable gets a single mechanism with no parents whose
            ``'compare'`` list holds all observed labels. When False, the
            hand-specified per-variable mechanisms below are returned instead.

    Returns:
        A 24-tuple, in this order:
        ``(DAG_desc, Complete_DAG_desc, Complete_DAG, complete_labels,
        Observed_DAG, label_names, image_labels, rep_labels, interv_queries,
        cf_queries, latent_conf, confTochild, exogenous, cf_intervene,
        cf_observe, cf_evidence, cflabel_names, twin_map, Twin_Network,
        cf_exogenous, noise_params, train_mech_dict, label_dim, plot_title)``.
    """
    DAG_desc = "trainGraph"
    Complete_DAG_desc = "trainGraph"
    plot_title = "Modular Training distribution convergence"

    # Observed structure: variable -> list of its observed parents.
    Observed_DAG = {
        "X0": [],
        "X1": [],
        "X2": ["X1"],
        "W1": ["X0", "X1", "X2"],
        "Y1": ["W1"],
        "W0": ["X0"],
        "Y0": ["W0", "W1"],
    }
    label_names = list(Observed_DAG)

    # Latent confounder -> observed variables it is a parent of.
    confTochild = {"U0": ["X0", "Y0"], "U1": ["X1", "X2"], "U2": ["X2", "Y1"]}

    # latent_conf[var]: confounders that are parents of var, in confTochild order.
    latent_conf = {var: [] for var in Observed_DAG}
    for conf, children in confTochild.items():
        for var in children:
            latent_conf[var].append(conf)

    # Complete DAG: confounders are parentless roots. Each observed variable's
    # parents are its confounders followed by its observed parents. The roots
    # are derived from confTochild, so the two can never disagree.
    Complete_DAG = {conf: [] for conf in confTochild}
    for var in Observed_DAG:
        # `+` builds a new list, so Complete_DAG does not alias latent_conf.
        Complete_DAG[var] = latent_conf[var] + Observed_DAG[var]
    complete_labels = list(Complete_DAG)

    image_labels = []
    rep_labels = []

    label_dim = {label: obs_state for label in Observed_DAG}
    for conf in confTochild:
        label_dim[conf] = latent_state

    # Interventional queries as (observed variables, intervened variables).
    # The last entry (no interventions) is the observational joint P(V).
    intervention_list = [
        (["X0", "W1", "W0", "Y0"], ["X1", "X2"]),
        (["X1", "X2", "W1", "Y1"], ["X0"]),
        (["X0", "X1", "X2", "W1", "Y1", "W0", "Y0"], []),
    ]
    interv_queries = []
    for obs, inter_vars in intervention_list:
        perms = generate_permutations([label_dim[lb] for lb in inter_vars])
        interv_queries.append({
            "obs": obs,
            "intervs": [dict(zip(inter_vars, comb)) for comb in perms],
            "expr": getdoKey(obs, inter_vars),  # math expression used downstream
        })

    # No counterfactual queries for this graph; these are returned empty.
    cf_queries = []
    cflabel_names = []
    Twin_Network = {}
    cf_exogenous = {}
    cf_intervene = {}
    cf_observe = []
    cf_evidence = {}
    twin_map = {}

    exogenous = {label: "n" + label for label in label_names if label not in image_labels}

    # One entry per noise source: "n<label>" for observed variables, plus one
    # per confounder. The first tuple element is 0.5 and 0.1 respectively.
    noise_params = {"n" + label: (0.5, noise_states) for label in Observed_DAG}
    for conf in confTochild:
        noise_params[conf] = (0.1, latent_state)

    # Hand-specified mechanisms: one list of {'parents', 'intv', 'compare'}
    # dicts per variable. The key order here is also the key order of the
    # returned dict.
    train_mech_dict = {
        "W0": [{'parents': ['X0'], 'intv': {}, 'compare': ['W0']}],
        "W1": [{'parents': ['X0', 'X1', 'X2'], 'intv': {}, 'compare': ['W1']}],
        "X0": [{'parents': ['X1', 'X2'], 'intv': {}, 'compare': ['X0', 'W1', 'W0', 'Y0']}],
        "Y0": [{'parents': ['X1', 'X2'], 'intv': {}, 'compare': ['X0', 'W1', 'W0', 'Y0']}],
        "X1": [{'parents': ['X0'], 'intv': {}, 'compare': ['X1', 'X2', 'W1', 'Y1']}],
        "X2": [{'parents': ['X0'], 'intv': {}, 'compare': ['X1', 'X2', 'W1', 'Y1']}],
        "Y1": [{'parents': ['X0'], 'intv': {}, 'compare': ['X1', 'X2', 'W1', 'Y1']}],
    }

    if full_training:
        # Full training: replace every mechanism with a parentless one compared
        # against all observed labels. Each mechanism gets its own copy of
        # label_names, so mutating one 'compare' list cannot change the
        # others or the returned label_names.
        train_mech_dict = {
            lb: [{'parents': [], 'intv': {}, 'compare': list(label_names)}]
            for lb in train_mech_dict
        }

    print("printing")
    for label in label_names:
        print(label, train_mech_dict[label])

    for label in Observed_DAG:
        if label not in image_labels:
            label_dim["n" + label] = noise_states

    return DAG_desc, Complete_DAG_desc, Complete_DAG, complete_labels, Observed_DAG, label_names, image_labels, rep_labels, interv_queries, cf_queries, latent_conf, \
           confTochild, exogenous, cf_intervene, cf_observe, cf_evidence, cflabel_names, twin_map, Twin_Network, cf_exogenous, \
           noise_params, train_mech_dict, label_dim, plot_title



